from econml.deepiv import DeepIVEstimator
from econml.bootstrap import BootstrapEstimator
import keras
import numpy as np
import csv
from itertools import product
from sklearn.linear_model import (Lasso, LassoCV, LogisticRegression,
                                  LogisticRegressionCV, LinearRegression,
                                  MultiTaskElasticNet, MultiTaskElasticNetCV)
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.preprocessing import PolynomialFeatures
from sklearn.model_selection import train_test_split
import pandas as pd
import matplotlib.pyplot as plt
# load data
import pandas as pd
import random

# Parse the feature table once; read_csv already returns a DataFrame, and the
# column listing below is derived from the same frame instead of re-reading
# the file.
data = pd.read_csv('infos_feature_new.csv')

# header
# Zero-row view that keeps only the column labels; iterating it yields the
# column names, printed with their positions so the positional slices below
# can be checked against the file layout.
header = data.iloc[:0]
for i, j in enumerate(header):
    print(i, j)

# Model inputs, selected by column position:
#   x - columns 3..12 plus columns 23 onward (features, as an ndarray)
#   z - column 15 (passed as the instrument Z when fitting)
#   t - column 14 (passed as the treatment T when fitting)
#   y - column 2  (passed as the outcome Y when fitting)
x = np.array(pd.concat([data.iloc[:, 3:13], data.iloc[:, 23:]], axis=1))
z = data.iloc[:, 15]
t = data.iloc[:, 14]
y = data.iloc[:, 2]

# Seed the random number generators this script can reach directly. Seeding
# Python's `random` module does not touch NumPy's global generator, so NumPy
# is seeded as well (presumably the bootstrap resampling draws from it —
# confirm against the installed econml version).
# NOTE(review): the Keras backend's own RNG (weight initialisation, dropout)
# is not seeded here, so fitted networks may still differ between runs; seed
# it with whatever API the installed backend provides.
random.seed(777)
np.random.seed(777)

# Network passed to DeepIVEstimator as the treatment model `m`. It receives the
# concatenation of z and x (declared as 267 input columns) and ends on the
# 32-unit block; no output layer is defined here.
treatment_model = keras.Sequential([keras.layers.Dense(128, activation='relu', input_shape=(267,)),
                                    keras.layers.BatchNormalization(),
                                    keras.layers.Dropout(0.2),
                                    keras.layers.Dense(64, activation='relu'),
                                    keras.layers.BatchNormalization(),
                                    keras.layers.Dropout(0.2),
                                    keras.layers.Dense(32, activation='tanh'),
                                    keras.layers.BatchNormalization(),
                                    keras.layers.Dropout(0.2)
                                    ])

# Network passed to DeepIVEstimator as the response model `h`. It receives the
# concatenation of t and x (same declared width of 267) and ends in a single
# linear output unit.
response_model = keras.Sequential([keras.layers.Dense(128, activation='tanh', input_shape=(267,)),
                                   keras.layers.BatchNormalization(),
                                   keras.layers.Dropout(0.2),
                                   keras.layers.Dense(64, activation='relu'),
                                   keras.layers.BatchNormalization(),
                                   keras.layers.Dropout(0.2),
                                   keras.layers.Dense(32, activation='relu'),
                                   keras.layers.BatchNormalization(),
                                   keras.layers.Dropout(0.2),
                                   keras.layers.Dense(1)
                                   ])

# Keras fit() options shared by both estimation stages: up to 100 epochs,
# 10% of the data held out for validation, and early stopping after two
# epochs without improvement (restoring the best weights seen).
keras_fit_options = {"epochs": 100,
                     "validation_split": 0.1,
                     "batch_size": 256,
                     "callbacks": [keras.callbacks.EarlyStopping(patience=2, restore_best_weights=True)]
                     }

deepIvEst = DeepIVEstimator(n_components=1000,  # number of gaussians in our mixture density network
                            m=lambda z, x: treatment_model(keras.layers.concatenate([z, x], axis=1)),  # treatment model
                            h=lambda t, x: response_model(keras.layers.concatenate([t, x], axis=1)),  # response model
                            n_samples=1,
                            use_upper_bound_loss=False,
                            n_gradient_samples=1,
                            optimizer='adam',
                            first_stage_options=keras_fit_options,
                            second_stage_options=keras_fit_options)
# Wrap the estimator so that 1000 bootstrap replicates are fitted serially
# (n_jobs=1); the *_interval methods used later come from this wrapper.
boot_est = BootstrapEstimator(deepIvEst, n_bootstrap_samples=1000, n_jobs=1)
boot_est.fit(Y=y, T=t, X=x, Z=z)

# average
# Sweep the treatment over a fixed 100-point grid on [0, 20]. At each grid
# value, every row of x is evaluated at that same treatment level and the
# marginal effects (and the 2.5/97.5 percentile bounds) are averaged over rows.
# NOTE(review): the plotted band is the mean of per-row percentile bounds,
# which is not the same thing as a bootstrap interval of the mean effect.
t1 = np.linspace(0, 20, num=100)
y_ate_1 = []
upper1 = []
lower1 = []
n_rows = x.shape[0]
for t11 in t1:
    t_const = np.full(n_rows, t11)  # same treatment value for every row
    y_ate_1.append(np.mean(boot_est.marginal_effect(t_const, x)))
    elas_int = boot_est.marginal_effect_interval(t_const, x, lower=2.5, upper=97.5)
    lower1.append(np.mean(elas_int[0]))
    upper1.append(np.mean(elas_int[1]))

plt.figure()
plt.plot(t1, y_ate_1, label='pred elast')
plt.fill_between(t1, lower1, upper1, alpha=.3, label="95% CI")
plt.xlabel('log(follower)')
plt.ylabel('elasticities')
plt.legend()
# Save before showing: once a blocking show() returns the figure has been
# closed, and a later savefig() would write a new, empty figure.
plt.savefig('image_ate/avg_ate.jpg')
plt.show()
# Average predicted outcome curve: for every row of x, predict the outcome at
# each treatment value in t1 (the row is repeated once per grid point), then
# accumulate the per-row curves and interval bounds into running means.
n_grid = len(t1)
n_rows = len(x)
y_pred = 0
lower2 = np.zeros(n_grid)
upper2 = np.zeros(n_grid)
for x_m in x:
    x2 = np.tile(x_m, (n_grid, 1))  # one copy of this row per treatment grid value
    y_pred += boot_est.predict(t1, x2) / n_rows
    elas_int = boot_est.predict_interval(t1, x2, lower=2.5, upper=97.5)
    lower2 += np.array(elas_int[0]) / n_rows
    upper2 += np.array(elas_int[1]) / n_rows

plt.figure()
plt.plot(t1, y_pred, label='pred impr')
plt.fill_between(t1, lower2, upper2, alpha=.3, label="95% CI")
plt.xlabel('log(follower)')
plt.ylabel('log(impr)')
plt.legend()
# Save before showing: once a blocking show() returns the figure has been
# closed, and a later savefig() would write a new, empty figure.
plt.savefig('image_ate/avg_pred.jpg')
plt.show()

# Mean height of the average marginal-effect curve over the treatment grid:
# trapezoid-rule area divided by the width of the grid. The spacing is taken
# from t1 itself (100 points on [0, 20] are 20/99 apart, not 0.2), so the
# area and the normalising span stay consistent with the grid.
_avg_curve = np.asarray(y_ate_1, dtype=float)
avg_area = np.sum((_avg_curve[:-1] + _avg_curve[1:]) * np.diff(t1) / 2) / (t1[-1] - t1[0])

# heterogeneous: engagement
# Split the rows by the three engagement indicator columns and rebuild the
# same positional feature / instrument / treatment / outcome slices per subset.
data_ent = data.loc[data['entertaining'] == 1]
data_inf = data.loc[data['informational'] == 1]
data_soc = data.loc[data['socializing'] == 1]
x_ent = np.array(pd.concat([data_ent.iloc[:, 3:13], data_ent.iloc[:, 23:]], axis=1))
z_ent = data_ent.iloc[:, 15]
t_ent = data_ent.iloc[:, 14]
y_ent = data_ent.iloc[:, 2]
x_inf = np.array(pd.concat([data_inf.iloc[:, 3:13], data_inf.iloc[:, 23:]], axis=1))
z_inf = data_inf.iloc[:, 15]
t_inf = data_inf.iloc[:, 14]
y_inf = data_inf.iloc[:, 2]
x_soc = np.array(pd.concat([data_soc.iloc[:, 3:13], data_soc.iloc[:, 23:]], axis=1))
z_soc = data_soc.iloc[:, 15]
t_soc = data_soc.iloc[:, 14]
y_soc = data_soc.iloc[:, 2]

# Per-subset curves: mean marginal effect and mean 2.5/97.5 percentile bounds
# at every treatment value in t1, using the estimator fitted on the full data.
y_ate_ent = []
upper_ent = []
lower_ent = []
y_ate_inf = []
upper_inf = []
lower_inf = []
y_ate_soc = []
upper_soc = []
lower_soc = []
for t11 in t1:
    t_ent_const = np.full(x_ent.shape[0], t11)
    y_ate_ent.append(np.mean(boot_est.marginal_effect(t_ent_const, x_ent)))
    elas_int_ent = boot_est.marginal_effect_interval(t_ent_const, x_ent, lower=2.5, upper=97.5)
    lower_ent.append(np.mean(elas_int_ent[0]))
    upper_ent.append(np.mean(elas_int_ent[1]))

    t_inf_const = np.full(x_inf.shape[0], t11)
    y_ate_inf.append(np.mean(boot_est.marginal_effect(t_inf_const, x_inf)))
    elas_int_inf = boot_est.marginal_effect_interval(t_inf_const, x_inf, lower=2.5, upper=97.5)
    lower_inf.append(np.mean(elas_int_inf[0]))
    upper_inf.append(np.mean(elas_int_inf[1]))

    t_soc_const = np.full(x_soc.shape[0], t11)
    y_ate_soc.append(np.mean(boot_est.marginal_effect(t_soc_const, x_soc)))
    elas_int_soc = boot_est.marginal_effect_interval(t_soc_const, x_soc, lower=2.5, upper=97.5)
    lower_soc.append(np.mean(elas_int_soc[0]))
    upper_soc.append(np.mean(elas_int_soc[1]))

plt.figure()
plt.plot(t1, y_ate_ent, label="Entertaining", linestyle='--', color='palevioletred', linewidth=2.5)
plt.fill_between(t1, lower_ent, upper_ent, color='lightpink', alpha=.2)
plt.plot(t1, y_ate_inf, label="Informational", linestyle=':', color='darkturquoise', linewidth=2.5)
plt.fill_between(t1, lower_inf, upper_inf, alpha=.2, color='mediumturquoise')
plt.plot(t1, y_ate_soc, label="Socializing", color='orchid', linewidth=2.5)
plt.fill_between(t1, lower_soc, upper_soc, alpha=.2, color='thistle')
plt.xlabel('log(follower)')
plt.ylabel('elasticities')
plt.legend()
# Save before showing: once a blocking show() returns the figure has been
# closed, and a later savefig() would write a new, empty figure.
plt.savefig('image_ate/eng_ate.jpg')
plt.show()

def _mean_over_grid(values, grid):
    """Return the trapezoid-rule area under `values` over `grid`, divided by the grid's width."""
    curve = np.asarray(values, dtype=float)
    # Use the grid's real spacing: 100 points on [0, 20] are 20/99 apart, not 0.2.
    area = np.sum((curve[:-1] + curve[1:]) * np.diff(grid) / 2)
    return area / (grid[-1] - grid[0])


# Mean marginal effect over the treatment grid for each engagement subset.
avg_area_ent = _mean_over_grid(y_ate_ent, t1)
avg_area_inf = _mean_over_grid(y_ate_inf, t1)
avg_area_soc = _mean_over_grid(y_ate_soc, t1)

# heterogeneous: topic
# Split the rows by the value of the 'topics' column (codes 1-5; the plot
# legend below labels them Life, Holidays, Skills, Food, Gaming).
data_life = data.loc[data['topics'] == 1]
data_holi = data.loc[data['topics'] == 2]
data_skill = data.loc[data['topics'] == 3]
data_food = data.loc[data['topics'] == 4]
data_gam = data.loc[data['topics'] == 5]

# Same positional slices as the full-data inputs. The feature blocks are
# converted to ndarrays so the estimator receives the same input type here as
# in every other section of this script.
x_life = np.array(pd.concat([data_life.iloc[:, 3:13], data_life.iloc[:, 23:]], axis=1))
z_life = data_life.iloc[:, 15]
t_life = data_life.iloc[:, 14]
y_life = data_life.iloc[:, 2]
x_holi = np.array(pd.concat([data_holi.iloc[:, 3:13], data_holi.iloc[:, 23:]], axis=1))
z_holi = data_holi.iloc[:, 15]
t_holi = data_holi.iloc[:, 14]
y_holi = data_holi.iloc[:, 2]
x_skill = np.array(pd.concat([data_skill.iloc[:, 3:13], data_skill.iloc[:, 23:]], axis=1))
z_skill = data_skill.iloc[:, 15]
t_skill = data_skill.iloc[:, 14]
y_skill = data_skill.iloc[:, 2]
x_food = np.array(pd.concat([data_food.iloc[:, 3:13], data_food.iloc[:, 23:]], axis=1))
z_food = data_food.iloc[:, 15]
t_food = data_food.iloc[:, 14]
y_food = data_food.iloc[:, 2]
x_gam = np.array(pd.concat([data_gam.iloc[:, 3:13], data_gam.iloc[:, 23:]], axis=1))
z_gam = data_gam.iloc[:, 15]
t_gam = data_gam.iloc[:, 14]
y_gam = data_gam.iloc[:, 2]

# Per-topic curves: mean marginal effect and mean 2.5/97.5 percentile bounds
# at every treatment value in t1, using the estimator fitted on the full data.
y_ate_life = []
upper_life = []
lower_life = []
y_ate_holi = []
upper_holi = []
lower_holi = []
y_ate_skill = []
upper_skill = []
lower_skill = []
y_ate_food = []
upper_food = []
lower_food = []
y_ate_gam = []
upper_gam = []
lower_gam = []
for t11 in t1:
    t_life_const = np.full(x_life.shape[0], t11)
    y_ate_life.append(np.mean(boot_est.marginal_effect(t_life_const, x_life)))
    elas_int_life = boot_est.marginal_effect_interval(t_life_const, x_life, lower=2.5, upper=97.5)
    lower_life.append(np.mean(elas_int_life[0]))
    upper_life.append(np.mean(elas_int_life[1]))

    t_holi_const = np.full(x_holi.shape[0], t11)
    y_ate_holi.append(np.mean(boot_est.marginal_effect(t_holi_const, x_holi)))
    elas_int_holi = boot_est.marginal_effect_interval(t_holi_const, x_holi, lower=2.5, upper=97.5)
    lower_holi.append(np.mean(elas_int_holi[0]))
    upper_holi.append(np.mean(elas_int_holi[1]))

    t_skill_const = np.full(x_skill.shape[0], t11)
    y_ate_skill.append(np.mean(boot_est.marginal_effect(t_skill_const, x_skill)))
    elas_int_skill = boot_est.marginal_effect_interval(t_skill_const, x_skill, lower=2.5, upper=97.5)
    lower_skill.append(np.mean(elas_int_skill[0]))
    upper_skill.append(np.mean(elas_int_skill[1]))

    t_food_const = np.full(x_food.shape[0], t11)
    y_ate_food.append(np.mean(boot_est.marginal_effect(t_food_const, x_food)))
    elas_int_food = boot_est.marginal_effect_interval(t_food_const, x_food, lower=2.5, upper=97.5)
    lower_food.append(np.mean(elas_int_food[0]))
    upper_food.append(np.mean(elas_int_food[1]))

    t_gam_const = np.full(x_gam.shape[0], t11)
    y_ate_gam.append(np.mean(boot_est.marginal_effect(t_gam_const, x_gam)))
    elas_int_gam = boot_est.marginal_effect_interval(t_gam_const, x_gam, lower=2.5, upper=97.5)
    lower_gam.append(np.mean(elas_int_gam[0]))
    upper_gam.append(np.mean(elas_int_gam[1]))

plt.figure()
plt.plot(t1, y_ate_life, label="Life", color='cadetblue', linestyle='dashed', alpha=0.6, linewidth=2)
plt.fill_between(t1, lower_life, upper_life, color='powderblue', alpha=.2)
plt.plot(t1, y_ate_holi, label="Holidays", color='yellowgreen', linewidth=2)
plt.fill_between(t1, lower_holi, upper_holi, alpha=.2, color='yellowgreen', linewidth=2)
plt.plot(t1, y_ate_skill, label="Skills", color='darkgrey', linestyle='dotted', linewidth=2)
plt.fill_between(t1, lower_skill, upper_skill, alpha=.2, color='silver')
plt.plot(t1, y_ate_food, label="Food", color='goldenrod', linestyle='dashdot', linewidth=1.8)
plt.fill_between(t1, lower_food, upper_food, color='goldenrod', alpha=.2)
plt.plot(t1, y_ate_gam, label="Gaming", color='maroon', linestyle=(0, (3, 1, 1, 1)), alpha=0.6, linewidth=2)
plt.fill_between(t1, lower_gam, upper_gam, color='lightpink', alpha=.2)
plt.xlabel('log(follower)')
plt.ylabel('elasticities')
plt.legend()
# Save before showing: once a blocking show() returns the figure has been
# closed, and a later savefig() would write a new, empty figure.
plt.savefig('image_ate/topic_ate.jpg')
plt.show()

def _mean_over_grid(values, grid):
    """Return the trapezoid-rule area under `values` over `grid`, divided by the grid's width."""
    curve = np.asarray(values, dtype=float)
    # Use the grid's real spacing: 100 points on [0, 20] are 20/99 apart, not 0.2.
    area = np.sum((curve[:-1] + curve[1:]) * np.diff(grid) / 2)
    return area / (grid[-1] - grid[0])


# Mean marginal effect over the treatment grid for each topic subset.
avg_area_life = _mean_over_grid(y_ate_life, t1)
avg_area_holi = _mean_over_grid(y_ate_holi, t1)
avg_area_skill = _mean_over_grid(y_ate_skill, t1)
avg_area_food = _mean_over_grid(y_ate_food, t1)
avg_area_gam = _mean_over_grid(y_ate_gam, t1)

# sponsored
# Hashtag names matched against the 'hashtag_name' column.
sphts = ['unwrapthedeals', 'whatsyourpower', 'videosnapchallenge', 'getcrocd', 'calistarchallenge', 'handwashchallenge',
         'morehappydenimdance', 'scoobdance', 'asosfashunweek', 'thesplashdance', 'readysetgo', 'letsfaceit',
         'dopacsun',
         'merrybossmas', 'moodflip', 'thisisbliss', 'upthebeat', 'katespadenyhappydance', 'perfectasIam',
         'expressieyourself',
         'closeyourrings', 'itwasntme', 'heinzhalloween', 'strictlycurl', 'monclerbubbleup', 'showupshowoff',
         'gotmilkchallenge',
         'goforthehandful', 'micellarrewind', 'cancelthenoise']

# Only the first hashtag in the list is processed (sphts[:1]); for it, the
# first six matching rows are used.
for spht in sphts[:1]:
    data_sp = data.loc[data['hashtag_name'] == spht]
    x_sp = np.array(pd.concat([data_sp.iloc[0:6, 3:13], data_sp.iloc[0:6, 23:]], axis=1))
    z_sp = data_sp.iloc[0:6, 15]
    t1 = np.linspace(0, 20, num=100)
    y_sp = data_sp.iloc[0:6, 2]
    y_ate_sp = []
    upper_sp = []
    lower_sp = []

    # Mean marginal effect and mean 2.5/97.5 percentile bounds over the
    # selected rows at every treatment value in t1.
    for t11 in t1:
        t_sp_const = np.full(x_sp.shape[0], t11)
        y_ate_sp.append(np.mean(boot_est.marginal_effect(t_sp_const, x_sp)))
        elas_sp = boot_est.marginal_effect_interval(t_sp_const, x_sp, lower=2.5, upper=97.5)
        lower_sp.append(np.mean(elas_sp[0]))
        upper_sp.append(np.mean(elas_sp[1]))

    plt.figure()
    plt.plot(t1, y_ate_sp, label="pred_elast")
    plt.fill_between(t1, lower_sp, upper_sp, label="95% CI", alpha=.3)
    plt.xlabel('log(follower)')
    plt.ylabel('elasticities')
    plt.legend()
    #plt.show()
    plt.savefig("image_ate/Hashtag_" + spht + '_ate.jpg')

    # Average predicted outcome curve for this hashtag. The running means go
    # into the *_sp accumulators created here, so the full-data averages
    # (y_pred / lower2 / upper2) computed earlier are left untouched.
    y_pred_sp = 0
    lower2_sp = np.zeros(len(t1))
    upper2_sp = np.zeros(len(t1))
    for x_m in x_sp:
        x2 = np.tile(x_m, (len(t1), 1))  # one copy of this row per treatment grid value
        y_pred_sp += boot_est.predict(t1, x2) / len(x_sp)
        elas_int = boot_est.predict_interval(t1, x2, lower=2.5, upper=97.5)
        lower2_sp += np.array(elas_int[0]) / len(x_sp)
        upper2_sp += np.array(elas_int[1]) / len(x_sp)

    # plt.figure()
    # plt.plot(t1,y_pred_sp,label="pred elast")
    # plt.fill_between(t1, lower2_sp, upper2_sp, alpha=.3, label="95% CI")
    # plt.xlabel('log(follower)')
    # plt.ylabel('log(impr)')
    # plt.legend()
    # plt.show()
    # plt.savefig("image_pred/Hashtag_"+spht+'_pred.jpg')

# case study: profit maximization for walmart
# Uses the first six rows tagged 'unwrapthedeals' and the same 100-point
# treatment grid on [0, 20] as the earlier sections.
t1 = np.linspace(0, 20, num=100)
data_walmart = data.loc[data['hashtag_name'] == 'unwrapthedeals']
walmart_rows = data_walmart.iloc[0:6]
x_walmart = np.array(pd.concat([walmart_rows.iloc[:, 3:13], walmart_rows.iloc[:, 23:]], axis=1))
z_walmart = walmart_rows.iloc[:, 15]
y_sp = walmart_rows.iloc[:, 2]

# Mean marginal effect and mean 2.5/97.5 percentile bounds over the selected
# rows at every grid value. These rebind y_ate_1 / lower1 / upper1.
y_ate_1 = []
upper1 = []
lower1 = []
n_walmart = x_walmart.shape[0]
for grid_value in t1:
    t_fixed = np.full(n_walmart, grid_value)
    y_ate_1.append(np.mean(boot_est.marginal_effect(t_fixed, x_walmart)))
    elas_sp = boot_est.marginal_effect_interval(t_fixed, x_walmart, lower=2.5, upper=97.5)
    lower1.append(np.mean(elas_sp[0]))
    upper1.append(np.mean(elas_sp[1]))

# Average predicted outcome curve: each selected row is repeated once per grid
# value, predicted across the whole grid, and folded into running means.
y_pred_walmart = 0
lower2_walmart = np.zeros(100)
upper2_walmart = np.zeros(100)
for row in x_walmart:
    row_grid = np.tile(row, (100, 1))
    y_pred_walmart += boot_est.predict(t1, row_grid) / len(x_walmart)
    bounds = boot_est.predict_interval(t1, row_grid, lower=2.5, upper=97.5)
    lower2_walmart += np.array(bounds[0]) / len(x_walmart)
    upper2_walmart += np.array(bounds[1]) / len(x_walmart)

# Dump the two Walmart curves to walmart.csv, one curve per row: first the
# mean marginal effects, then the mean predicted outcomes.
with open('walmart.csv', 'w', encoding='utf-8', newline='\n') as out_file:
    csv.writer(out_file).writerows([y_ate_1, y_pred_walmart])

# Finite-difference marginal revenue between consecutive grid points:
# 0.02 * (change in exp(predicted outcome)) / (change in exp(treatment)).
# Entry k covers the step from t1[k] to t1[k + 1], so the list has one element
# fewer than t1.
ypw_mean_mr1 = [
    0.02 * (np.exp(y_pred_walmart[i]) - np.exp(y_pred_walmart[i - 1])) / (np.exp(t1[i]) - np.exp(t1[i - 1]))
    for i in range(1, len(y_pred_walmart))
]

plt.figure()

# ypw_mean_mr1[33:] lines up with t1[34:] because of the one-element offset.
plt.plot(t1[34:], ypw_mean_mr1[33:], label="MR(follower) under $0.02/impr", color="darkseagreen")
plt.plot(t1[34:], [5 / 1000] * len(t1[34:]), label="MC=$5/1000 follower", color="#DE3163")
plt.xlabel('log(follower)')
plt.ylabel('#impression')
plt.title("MR and MC of Followers")
plt.legend()
# Save before showing: once a blocking show() returns the figure has been
# closed, and a later savefig() would write a new, empty figure.
plt.savefig('image_ate/walmart_ate.jpg')
plt.show()

# Profit at each grid point: 0.02 per unit of exp(predicted outcome) minus
# 5 per 1000 units of exp(treatment).
profit = [
    0.02 * np.exp(pred) - np.exp(level) * 5 / 1000
    for pred, level in zip(y_pred_walmart, t1)
]
plt.figure()
# Only grid indices 34..70 are plotted.
plt.plot(t1[34:71], profit[34:71])
plt.xlabel('log(follower)')
plt.ylabel('profits($)')
plt.title("The Firm's Profits")
# Display the figure like the other plots in this script; previously it was
# built but never shown or saved.
# NOTE(review): no output path was ever defined for this figure — add a
# savefig() before show() if an image file is wanted.
plt.show()

# Write heatmap.csv: one row per revenue coefficient (0.01 .. 0.20), one
# column per marginal-cost level (5 .. 100 in steps of 5, used as mc / 1000).
# Each cell is exp(t1[k + 1]) / 1000 for the grid step k in [33, 90) whose
# marginal revenue is above, and closest to, the marginal cost.
with open('heatmap.csv', 'w', encoding='utf-8', newline='\n') as fw:
    writer = csv.writer(fw)
    mc = list(range(5, 101, 5))
    header_row = ['coefficient']
    header_row.extend(mc)
    writer.writerow(header_row)
    coef = [0.01 * k for k in range(1, 21)]

    # The finite differences depend on neither the coefficient nor the cost
    # level, so compute them once.
    n_points = len(y_pred_walmart)
    d_outcome = [np.exp(y_pred_walmart[i]) - np.exp(y_pred_walmart[i - 1]) for i in range(1, n_points)]
    d_treat = [np.exp(t1[i]) - np.exp(t1[i - 1]) for i in range(1, n_points)]

    for co in coef:
        # Marginal-revenue curve for this coefficient; it is the same for
        # every cost level, so it is built outside the mc loop.
        yp1_mean_mr1 = [co * d_out / d_t for d_out, d_t in zip(d_outcome, d_treat)]
        indexs = [co]
        for mc1 in mc:
            mindist = 100
            # Fallback when no step in the searched range has MR above MC:
            # index stays 1 and the cell reports exp(t1[2]) / 1000.
            index = 1
            for i in range(33, 90):
                # if abs(yp1_mean_mr[i]-mc1/1000)<mindist:
                gap = yp1_mean_mr1[i] - mc1 / 1000
                if 0 < gap < mindist:
                    mindist = gap
                    index = i
            # print(mc1/1000,yp1_mean_mr[index],t1[index+1],)
            indexs.append(np.exp(t1[index + 1]) / 1000)
        writer.writerow(indexs)

# import seaborn as sns
# import matplotlib.pyplot as plt
# import pandas as pd
# import numpy as np
# import matplotlib.pyplot as plt
# import seaborn as sns
# dicts=[]
# with open('heatmap.csv','r',encoding='utf-8',newline='\n') as fr:
#   reader=csv.reader(fr)
#   next(reader)
#   for line in reader:
#     dicts.append([int(x) for x in line[1:]])
#
# sns.set_theme()
# plt.figure()
# xaixs=[5,10,15,20,25,50,75,100]
# yaixs=list(range(2,21,2))
# yaixs=[0.01*x for x in yaixs]
#
# fig, ax = plt.subplots(figsize=(9,6))
# sns.heatmap(dicts, xticklabels=xaixs,yticklabels=yaixs, linewidths=0.5,cmap=sns.cubehelix_palette(20,light=0.95,dark=0.15),ax=ax,annot=True,fmt='g')
# plt.xticks(rotation=0)
# plt.yticks(rotation=0)
# plt.xlabel('$cost/1K follower')
# plt.ylabel('$value/impr')
# plt.title( "Optimal Follower Size in K" )
# plt.show()
